{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 8.1 命令式和符号式混合编程"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": "10"
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def add(a, b):\n",
    "    return a + b\n",
    "\n",
    "def fancy_func(a, b, c, d):\n",
    "    e = add(a, b)\n",
    "    f = add(c, d)\n",
    "    g = add(e, f)\n",
    "    return g\n",
    "\n",
    "fancy_func(1, 2, 3, 4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": "\ndef add(a, b):\n    return a + b\n\ndef fancy_func(a, b, c, d):\n    e = add(a, b)\n    f = add(c, d)\n    g = add(e, f)\n    return g\n\nprint(fancy_func(1, 2, 3, 4))\n\n10\n"
    }
   ],
   "source": [
    "def add_str():\n",
    "    return '''\n",
    "def add(a, b):\n",
    "    return a + b\n",
    "'''\n",
    "\n",
    "def fancy_func_str():\n",
    "    return '''\n",
    "def fancy_func(a, b, c, d):\n",
    "    e = add(a, b)\n",
    "    f = add(c, d)\n",
    "    g = add(e, f)\n",
    "    return g\n",
    "'''\n",
    "\n",
    "def evoke_str():\n",
    "    return add_str() + fancy_func_str() + '''\n",
    "print(fancy_func(1, 2, 3, 4))\n",
    "'''\n",
    "\n",
    "prog = evoke_str()\n",
    "print(prog)\n",
    "y = compile(prog, '', 'exec')\n",
    "exec(y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 8.1.2 tf.function 的使用\n",
    "\n",
    "### 8.1.2.1 基础"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "import traceback\n",
    "import contextlib\n",
    "\n",
    "# Some helper code to demonstrate the kinds of errors you might encounter.\n",
    "@contextlib.contextmanager\n",
    "def assert_raises(error_class):\n",
    "  try:\n",
    "    yield\n",
    "  except error_class as e:\n",
    "    print('Caught expected exception \\n  {}:'.format(error_class))\n",
    "    traceback.print_exc(limit=2)\n",
    "  except Exception as e:\n",
    "    raise e\n",
    "  else:\n",
    "    raise Exception('Expected {} to be raised but no error was raised!'.format(\n",
    "        error_class))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": "<tf.Tensor: id=14, shape=(2, 2), dtype=float32, numpy=\narray([[2., 2.],\n       [2., 2.]], dtype=float32)>"
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same addition as the plain-Python `add` defined earlier, now decorated\n",
    "# with tf.function. NOTE: this rebinds the notebook-level name `add`.\n",
    "@tf.function\n",
    "def add(a, b):\n",
    "    \"\"\"Return ``a + b``.\"\"\"\n",
    "    return a+b\n",
    "\n",
    "add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": "<tf.Tensor: id=37, shape=(), dtype=float32, numpy=1.0>"
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Call the tf.function-decorated `add` under a GradientTape, then\n",
    "# differentiate the result with respect to the variable `v`.\n",
    "v = tf.Variable(1.0)\n",
    "with tf.GradientTape() as tape:\n",
    "    result = add(v, 1.0)\n",
    "# Last expression of the cell: the gradient is displayed (recorded value: 1.0).\n",
    "tape.gradient(result, v)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": "<tf.Tensor: id=63, shape=(3, 2), dtype=float32, numpy=\narray([[3., 3.],\n       [3., 3.],\n       [3., 3.]], dtype=float32)>"
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One tf.function (`dense_layer`) calling another (`add`).\n",
    "@tf.function\n",
    "def dense_layer(x, w, b):\n",
    "  \"\"\"Return ``add(tf.matmul(x, w), b)``.\"\"\"\n",
    "  return add(tf.matmul(x, w), b)\n",
    "\n",
    "dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 8.1.2.2 追踪与多态"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": "Tracing with Tensor(\"a:0\", shape=(), dtype=int32)\ntf.Tensor(2, shape=(), dtype=int32)\n\nTracing with Tensor(\"a:0\", shape=(), dtype=float32)\ntf.Tensor(2.2, shape=(), dtype=float32)\n\nTracing with Tensor(\"a:0\", shape=(), dtype=string)\ntf.Tensor(b'aa', shape=(), dtype=string)\n\n"
    }
   ],
   "source": [
    "# Functions are polymorphic\n",
    "\n",
    "@tf.function\n",
    "def double(a):\n",
    "  \"\"\"Return ``a + a``.\"\"\"\n",
    "  # Plain Python print: in the recorded output it shows a symbolic Tensor,\n",
    "  # once for each new input dtype.\n",
    "  print(\"Tracing with\", a)\n",
    "  return a + a\n",
    "\n",
    "# Three input dtypes (int32, float32, string) -> three 'Tracing with' lines.\n",
    "print(double(tf.constant(1)))\n",
    "print()\n",
    "print(double(tf.constant(1.1)))\n",
    "print()\n",
    "print(double(tf.constant(\"a\")))\n",
    "print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": "Obtaining concrete trace\nTracing with Tensor(\"a:0\", dtype=string)\nExecuting traced function\ntf.Tensor(b'aa', shape=(), dtype=string)\ntf.Tensor(b'bb', shape=(), dtype=string)\nUsing a concrete trace with incompatible types will throw an error\nCaught expected exception \n  <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>:\nTraceback (most recent call last):\n  File \"<ipython-input-3-c67b780c8118>\", line 10, in assert_raises\n    yield\n  File \"<ipython-input-8-5351d0a2eda2>\", line 8, in <module>\n    double_strings(tf.constant(1))\ntensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_90 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_90]\n"
    }
   ],
   "source": [
    "print(\"Obtaining concrete trace\")\n",
    "# Ask `double` for the trace that handles string tensors; shape=None leaves\n",
    "# the shape unspecified.\n",
    "double_strings = double.get_concrete_function(tf.TensorSpec(shape=None, dtype=tf.string))\n",
    "print(\"Executing traced function\")\n",
    "# The concrete function is called once positionally and once by keyword.\n",
    "print(double_strings(tf.constant(\"a\")))\n",
    "print(double_strings(a=tf.constant(\"b\")))\n",
    "print(\"Using a concrete trace with incompatible types will throw an error\")\n",
    "# An int32 tensor does not match the string-typed trace.\n",
    "with assert_raises(tf.errors.InvalidArgumentError):\n",
    "  double_strings(tf.constant(1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": "Tracing with Tensor(\"x:0\", shape=(None,), dtype=int32)\ntf.Tensor([4 1], shape=(2,), dtype=int32)\nCaught expected exception \n  <class 'ValueError'>:\nTraceback (most recent call last):\n  File \"<ipython-input-3-c67b780c8118>\", line 10, in assert_raises\n    yield\n  File \"<ipython-input-9-9939c82c1507>\", line 9, in <module>\n    next_collatz(tf.constant([[1, 2], [3, 4]]))\nValueError: Python inputs incompatible with input_signature:\n  inputs: (\n    tf.Tensor(\n[[1 2]\n [3 4]], shape=(2, 2), dtype=int32))\n  input_signature: (\n    TensorSpec(shape=(None,), dtype=tf.int32, name=None))\n"
    }
   ],
   "source": [
    "# The input_signature restricts calls to 1-D int32 tensors (length unspecified).\n",
    "@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))\n",
    "def next_collatz(x):\n",
    "  \"\"\"Element-wise Collatz step: ``x // 2`` where ``x`` is even, else ``3 * x + 1``.\"\"\"\n",
    "  print(\"Tracing with\", x)\n",
    "  return tf.where(x % 2 == 0, x // 2, 3 * x + 1)\n",
    "\n",
    "# Recorded result for [1, 2] is [4 1].\n",
    "print(next_collatz(tf.constant([1, 2])))\n",
    "# We specified a 1-D tensor in the input signature, so this should fail.\n",
    "with assert_raises(ValueError):\n",
    "  next_collatz(tf.constant([[1, 2], [3, 4]]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 8.1.2.3 追踪触发的时机\n",
    "\n",
    "### 8.1.2.4 输入参数的选择 Python or Tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": "Tracing with num_steps = 10\nTracing with num_steps = 20\n"
    }
   ],
   "source": [
    "def train_one_step():\n",
    "  \"\"\"Placeholder for a single training step; intentionally does nothing.\"\"\"\n",
    "  pass\n",
    "\n",
    "@tf.function\n",
    "def train(num_steps):\n",
    "  \"\"\"Call ``train_one_step`` once per element of ``tf.range(num_steps)``.\"\"\"\n",
    "  print(\"Tracing with num_steps = {}\".format(num_steps))\n",
    "  for _ in tf.range(num_steps):\n",
    "    train_one_step()\n",
    "\n",
    "# Python ints as arguments: the recorded output shows one trace for 10 and\n",
    "# another for 20.\n",
    "train(num_steps=10)\n",
    "train(num_steps=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": "Tracing with num_steps = Tensor(\"num_steps:0\", shape=(), dtype=int32)\n"
    }
   ],
   "source": [
    "# Tensor arguments: the recorded output shows a single trace for both calls.\n",
    "train(num_steps=tf.constant(10))\n",
    "train(num_steps=tf.constant(20))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": "Traced with 1\nExecuted with 1\nExecuted with 1\nTraced with 2\nExecuted with 2\n"
    }
   ],
   "source": [
    "@tf.function\n",
    "def f(x):\n",
    "  \"\"\"Contrast Python ``print`` with ``tf.print`` inside a tf.function.\"\"\"\n",
    "  print(\"Traced with\", x)\n",
    "  tf.print(\"Executed with\", x)\n",
    "\n",
    "# Recorded output: 'Traced with' appears once per distinct value (1, then 2),\n",
    "# while 'Executed with' appears for every call.\n",
    "f(1)\n",
    "f(1)\n",
    "f(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": "Python side effect\nPython side effect\nPython side effect\n"
    }
   ],
   "source": [
    "# Collects every value handed to side_effect.\n",
    "external_list = []\n",
    "\n",
    "def side_effect(x):\n",
    "  \"\"\"Print a marker line and append ``x`` to ``external_list``.\"\"\"\n",
    "  print('Python side effect')\n",
    "  external_list.append(x)\n",
    "\n",
    "@tf.function\n",
    "def f(x):\n",
    "  \"\"\"Invoke ``side_effect`` on ``x`` through ``tf.py_function``.\"\"\"\n",
    "  tf.py_function(side_effect, inp=[x], Tout=[])\n",
    "\n",
    "# Three calls with the same argument; the recorded output shows the marker\n",
    "# three times and the assert below expects three appended values.\n",
    "f(1)\n",
    "f(1)\n",
    "f(1)\n",
    "assert len(external_list) == 3\n",
    "# .numpy() call required because py_function casts 1 to tf.constant(1)\n",
    "assert external_list[0].numpy() == 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 8.1.2.6 注意 Python 的状态"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": "Value of external_var: 0\nValue of external_var: 0\nValue of external_var: 0\n"
    }
   ],
   "source": [
    "external_var = tf.Variable(0)\n",
    "@tf.function\n",
    "def buggy_consume_next(iterator):\n",
    "  \"\"\"Deliberately flawed demo: add ``next(iterator)`` to ``external_var`` and print it.\"\"\"\n",
    "  # next() is ordinary Python code; see the note at the call sites below.\n",
    "  external_var.assign_add(next(iterator))\n",
    "  tf.print(\"Value of external_var:\", external_var)\n",
    "\n",
    "iterator = iter([0, 1, 2, 3])\n",
    "buggy_consume_next(iterator)\n",
    "# This reuses the first value from the iterator, rather than consuming the next value.\n",
    "# (The recorded output prints 0 for all three calls.)\n",
    "buggy_consume_next(iterator)\n",
    "buggy_consume_next(iterator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": "train([(1, 1), (1, 1)]) contains 8 nodes in its graph\ntrain([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph\ntrain(<DatasetV1Adapter shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 5 nodes in its graph\ntrain(<DatasetV1Adapter shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 5 nodes in its graph\n"
    }
   ],
   "source": [
    "def measure_graph_size(f, *args):\n",
    "  g = f.get_concrete_function(*args).graph\n",
    "  print(\"{}({}) contains {} nodes in its graph\".format(\n",
    "      f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))\n",
    "\n",
    "@tf.function\n",
    "def train(dataset):\n",
    "  \"\"\"Accumulate ``tf.abs(y - x)`` over every ``(x, y)`` pair in ``dataset``.\"\"\"\n",
    "  loss = tf.constant(0)\n",
    "  for x, y in dataset:\n",
    "    loss += tf.abs(y - x) # Some dummy computation.\n",
    "  return loss\n",
    "\n",
    "small_data = [(1, 1)] * 2\n",
    "big_data = [(1, 1)] * 10\n",
    "# Python lists: the recorded node counts grow with the list length (8 vs 32).\n",
    "measure_graph_size(train, small_data)\n",
    "measure_graph_size(train, big_data)\n",
    "\n",
    "# tf.data Datasets: the recorded node count is 5 for both sizes.\n",
    "measure_graph_size(train, tf.data.Dataset.from_generator(\n",
    "    lambda: small_data, (tf.int32, tf.int32)))\n",
    "measure_graph_size(train, tf.data.Dataset.from_generator(\n",
    "    lambda: big_data, (tf.int32, tf.int32)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 8.1.2.7 自动控制依赖"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": "<tf.Tensor: id=417, shape=(), dtype=float32, numpy=10.0>"
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Automatic control dependencies\n",
    "\n",
    "a = tf.Variable(1.0)\n",
    "b = tf.Variable(2.0)\n",
    "\n",
    "@tf.function\n",
    "def f(x, y):\n",
    "  \"\"\"Assign ``y * b`` to ``a``, add ``x * a`` to ``b``, then return ``a + b``.\"\"\"\n",
    "  a.assign(y * b)\n",
    "  b.assign_add(x * a)\n",
    "  return a + b\n",
    "\n",
    "# In statement order: a = 2.0 * 2.0 = 4.0, b = 2.0 + 1.0 * 4.0 = 6.0, so\n",
    "# a + b = 10.0, which matches the recorded result.\n",
    "f(1.0, 2.0)  # 10.0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 8.1.2.8 变量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": "WARNING:tensorflow:From /Users/minds/miniconda3/envs/tf/lib/python3.7/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1781: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\nInstructions for updating:\nIf using Keras pass *_constraint arguments to layers.\nCaught expected exception \n  <class 'ValueError'>:\nTraceback (most recent call last):\n  File \"<ipython-input-3-c67b780c8118>\", line 10, in assert_raises\n    yield\n  File \"<ipython-input-17-73e410646579>\", line 8, in <module>\n    f(1.0)\nValueError: in converted code:\n\n    <ipython-input-17-73e410646579>:3 f  *\n        v = tf.Variable(1.0)\n    /Users/minds/miniconda3/envs/tf/lib/python3.7/site-packages/tensorflow_core/python/ops/variables.py:260 __call__\n        return cls._variable_v2_call(*args, **kwargs)\n    /Users/minds/miniconda3/envs/tf/lib/python3.7/site-packages/tensorflow_core/python/ops/variables.py:254 _variable_v2_call\n        shape=shape)\n    /Users/minds/miniconda3/envs/tf/lib/python3.7/site-packages/tensorflow_core/python/ops/variables.py:65 getter\n        return captured_getter(captured_previous, **kwargs)\n    /Users/minds/miniconda3/envs/tf/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py:413 invalid_creator_scope\n        \"tf.function-decorated function tried to create \"\n\n    ValueError: tf.function-decorated function tried to create variables on non-first call.\n\n"
    }
   ],
   "source": [
    "@tf.function\n",
    "def f(x):\n",
    "  \"\"\"Try to create a variable inside the decorated function, add ``x``, return it.\"\"\"\n",
    "  v = tf.Variable(1.0)\n",
    "  v.assign_add(x)\n",
    "  return v\n",
    "\n",
    "# Recorded output: ValueError, 'tf.function-decorated function tried to\n",
    "# create variables on non-first call'.\n",
    "with assert_raises(ValueError):\n",
    "  f(1.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": "tf.Tensor(2.0, shape=(), dtype=float32)\ntf.Tensor(4.0, shape=(), dtype=float32)\n"
    }
   ],
   "source": [
    "# The variable is created once, outside the decorated function.\n",
    "v = tf.Variable(1.0)\n",
    "\n",
    "@tf.function\n",
    "def f(x):\n",
    "  \"\"\"Add ``x`` to the notebook-level variable ``v`` and return the ``assign_add`` result.\"\"\"\n",
    "  return v.assign_add(x)\n",
    "\n",
    "print(f(1.0))  # 2.0\n",
    "print(f(2.0))  # 4.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": "tf.Tensor(2.0, shape=(), dtype=float32)\ntf.Tensor(4.0, shape=(), dtype=float32)\n"
    }
   ],
   "source": [
    "class C:\n",
    "  \"\"\"Empty holder class; its instance carries the lazily created variable.\"\"\"\n",
    "  pass\n",
    "\n",
    "obj = C()\n",
    "obj.v = None\n",
    "\n",
    "@tf.function\n",
    "def g(x):\n",
    "  \"\"\"Create ``obj.v`` only while it is still ``None``, then add ``x`` to it.\"\"\"\n",
    "  if obj.v is None:\n",
    "    obj.v = tf.Variable(1.0)\n",
    "  return obj.v.assign_add(x)\n",
    "\n",
    "print(g(1.0))  # 2.0\n",
    "print(g(2.0))  # 4.0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 8.1.3 AutoGraph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": "[0.924424171 0.986872435 0.888964534 0.444825053 0.725420952]\n[0.727983654 0.756025612 0.710881948 0.417635918 0.620255947]\n[0.621830225 0.638730049 0.611229599 0.394936919 0.551306188]\n[0.552401066 0.564034224 0.544992 0.375608444 0.50149852]\n[0.502317607 0.510964513 0.496757448 0.358887583 0.463294864]\n[0.463937879 0.47069636 0.459563255 0.344233781 0.432765812]\n[0.433288246 0.438761801 0.42972818 0.33125186 0.407630116]\n[0.408065647 0.412617594 0.405094087 0.319645166 0.386458606]\n[0.386829 0.390693 0.384299189 0.309186041 0.36830318]\n[0.368623316 0.371957481 0.36643523 0.299696445 0.352506608]\n[0.352786958 0.355702698 0.350869715 0.291034758 0.338596642]\n[0.338844836 0.341423243 0.33714661 0.283086896 0.32622394]\n[0.326445729 0.328747422 0.324927628 0.275759637 0.315123737]\n[0.315323502 0.317394823 0.313955665 0.268976 0.305091113]\n[0.305272281 0.30714941 0.304031432 0.262671709 0.295964688]\n[0.29613 0.297841579 0.294997573 0.256792754 0.287615418]\n[0.287767023 0.289336115 0.286728024 0.251293242 0.279938519]\n[0.280078292 0.281523645 0.279120505 0.246133909 0.272848159]\n[0.272977531 0.274314642 0.272090882 0.241281047 0.266273022]\n[0.266393244 0.267635018 0.265569299 0.236705363 0.260153472]\n[0.260265559 0.261422813 0.259497255 0.232381389 0.254439056]\n[0.2545439 0.255625874 0.253825247 0.228286847 0.249086857]\n[0.24918519 0.250199705 0.248511061 0.22440207 0.244060069]\n[0.244152546 0.245106369 0.243518502 0.220709711 0.239326939]\n[0.239414126 0.240313053 0.238816366 0.217194393 0.234859884]\n[0.234942272 0.235791385 0.234377414 0.213842332 0.230634838]\n[0.230712846 0.231516585 0.230177969 0.210641295 0.226630658]\n[0.226704657 0.227466956 0.226197228 0.207580224 0.222828716]\n[0.22289902 0.223623335 0.222416759 0.204649225 0.219212502]\n[0.219279423 0.219968796 0.218820289 0.201839283 0.215767339]\n[0.215831131 0.216488302 0.21539335 0.199142292 0.212480143]\n[0.212541044 0.213168442 0.212122992 0.196550876 
0.209339157]\n[0.209397405 0.209997207 0.208997652 0.194058314 0.206333846]\n[0.206389636 0.206963837 0.206006855 0.191658452 0.203454703]\n[0.203508198 0.204058558 0.203141257 0.189345658 0.20069316]\n"
    },
    {
     "data": {
      "text/plain": "<tf.Tensor: id=571, shape=(5,), dtype=float32, numpy=\narray([0.20074446, 0.20127262, 0.2003923 , 0.18711483, 0.19804138],\n      dtype=float32)>"
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Simple loop\n",
    "\n",
    "@tf.function\n",
    "def f(x):\n",
    "  \"\"\"Apply ``tf.tanh`` to ``x`` while ``tf.reduce_sum(x) > 1``, printing ``x`` on each pass.\"\"\"\n",
    "  while tf.reduce_sum(x) > 1:\n",
    "    tf.print(x)\n",
    "    x = tf.tanh(x)\n",
    "  return x\n",
    "\n",
    "# The input is random and no seed is set in this notebook, so the printed\n",
    "# values differ between runs.\n",
    "f(tf.random.uniform([5]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": "def tf__f(x):\n  do_return = False\n  retval_ = ag__.UndefinedReturnValue()\n  with ag__.FunctionScope('f', 'f_scope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as f_scope:\n\n    def get_state():\n      return ()\n\n    def set_state(_):\n      pass\n\n    def loop_body(x):\n      ag__.converted_call(tf.print, f_scope.callopts, (x,), None, f_scope)\n      x = ag__.converted_call(tf.tanh, f_scope.callopts, (x,), None, f_scope)\n      return x,\n\n    def loop_test(x):\n      return ag__.converted_call(tf.reduce_sum, f_scope.callopts, (x,), None, f_scope) > 1\n    x, = ag__.while_stmt(loop_test, loop_body, get_state, set_state, (x,), ('x',), ())\n    do_return = True\n    retval_ = f_scope.mark_return_value(x)\n  do_return,\n  return ag__.retval(retval_)\n\n"
    }
   ],
   "source": [
    "# Same loop as in the previous cell, but left undecorated so that\n",
    "# tf.autograph.to_code can print the code AutoGraph generates for it.\n",
    "def f(x):\n",
    "  while tf.reduce_sum(x) > 1:\n",
    "    tf.print(x)\n",
    "    x = tf.tanh(x)\n",
    "  return x\n",
    "\n",
    "print(tf.autograph.to_code(f))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 8.1.3.1 AutoGraph：条件分支"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_tf_cond(f, *args):\n",
    "  g = f.get_concrete_function(*args).graph\n",
    "  if any(node.name == 'cond' for node in g.as_graph_def().node):\n",
    "    print(\"{}({}) uses tf.cond.\".format(\n",
    "        f.__name__, ', '.join(map(str, args))))\n",
    "  else:\n",
    "    print(\"{}({}) executes normally.\".format(\n",
    "        f.__name__, ', '.join(map(str, args))))\n",
    "\n",
    "  print(\"  result: \",f(*args).numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": "dropout(tf.Tensor([1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], shape=(10,), dtype=float32), True) executes normally.\n  result:  [0. 0. 2. 0. 2. 0. 2. 2. 2. 2.]\n"
    }
   ],
   "source": [
    "@tf.function\n",
    "def dropout(x, training=True):\n",
    "  \"\"\"Apply ``tf.nn.dropout`` with rate 0.5 when ``training`` is true; otherwise return ``x``.\"\"\"\n",
    "  if training:\n",
    "    x = tf.nn.dropout(x, rate=0.5)\n",
    "  return x\n",
    "\n",
    "# Python bool for `training`: the recorded output reports 'executes normally'.\n",
    "test_tf_cond(dropout, tf.ones([10], dtype=tf.float32), True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": "dropout(tf.Tensor([1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], shape=(10,), dtype=float32), tf.Tensor(True, shape=(), dtype=bool)) uses tf.cond.\n  result:  [0. 2. 0. 2. 2. 0. 2. 2. 0. 2.]\n"
    }
   ],
   "source": [
    "# Tensor bool for `training`: the recorded output reports 'uses tf.cond'.\n",
    "test_tf_cond(dropout, tf.ones([10], dtype=tf.float32), tf.constant(True))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 8.1.3.2 AutoGraph 与循环"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_dynamically_unrolled(f, *args):\n",
    "  g = f.get_concrete_function(*args).graph\n",
    "  if any(node.name == 'while' for node in g.as_graph_def().node):\n",
    "    print(\"{}({}) uses tf.while_loop.\".format(\n",
    "        f.__name__, ', '.join(map(str, args))))\n",
    "  elif any(node.name == 'ReduceDataset' for node in g.as_graph_def().node):\n",
    "    print(\"{}({}) uses tf.data.Dataset.reduce.\".format(\n",
    "        f.__name__, ', '.join(map(str, args))))\n",
    "  else:\n",
    "    print(\"{}({}) gets unrolled.\".format(\n",
    "        f.__name__, ', '.join(map(str, args))))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### For 循环"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": "for_in_range() gets unrolled.\n"
    }
   ],
   "source": [
    "@tf.function\n",
    "def for_in_range():\n",
    "  \"\"\"Sum ``0..4`` with a ``for`` loop over Python's ``range``.\"\"\"\n",
    "  x = 0\n",
    "  for i in range(5):\n",
    "    x += i\n",
    "  return x\n",
    "\n",
    "# Recorded verdict: 'gets unrolled'.\n",
    "test_dynamically_unrolled(for_in_range)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": "for_in_tfrange() uses tf.while_loop.\n"
    }
   ],
   "source": [
    "@tf.function\n",
    "def for_in_tfrange():\n",
    "  \"\"\"Sum ``0..4`` with a ``for`` loop over ``tf.range``.\"\"\"\n",
    "  x = tf.constant(0, dtype=tf.int32)\n",
    "  for i in tf.range(5):\n",
    "    x += i\n",
    "  return x\n",
    "\n",
    "# Recorded verdict: 'uses tf.while_loop'.\n",
    "test_dynamically_unrolled(for_in_tfrange)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": "for_in_tfdataset() uses tf.data.Dataset.reduce.\n"
    }
   ],
   "source": [
    "@tf.function\n",
    "def for_in_tfdataset():\n",
    "  \"\"\"Sum ``0..4`` with a ``for`` loop over ``tf.data.Dataset.range``.\"\"\"\n",
    "  x = tf.constant(0, dtype=tf.int64)\n",
    "  for i in tf.data.Dataset.range(5):\n",
    "    x += i\n",
    "  return x\n",
    "\n",
    "# Recorded verdict: 'uses tf.data.Dataset.reduce'.\n",
    "test_dynamically_unrolled(for_in_tfdataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### While 循环"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": "while_py_cond() gets unrolled.\n"
    }
   ],
   "source": [
    "@tf.function\n",
    "def while_py_cond():\n",
    "  \"\"\"Count a Python int down from 5 to 0 with a ``while`` loop.\"\"\"\n",
    "  x = 5\n",
    "  while x > 0:\n",
    "    x -= 1\n",
    "  return x\n",
    "\n",
    "# Recorded verdict: 'gets unrolled'.\n",
    "test_dynamically_unrolled(while_py_cond)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": "while_tf_cond() uses tf.while_loop.\n"
    }
   ],
   "source": [
    "@tf.function\n",
    "def while_tf_cond():\n",
    "  \"\"\"Count a ``tf.constant`` down from 5 to 0 with a ``while`` loop.\"\"\"\n",
    "  x = tf.constant(5)\n",
    "  while x > 0:\n",
    "    x -= 1\n",
    "  return x\n",
    "\n",
    "# Recorded verdict: 'uses tf.while_loop'.\n",
    "test_dynamically_unrolled(while_tf_cond)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.4 64-bit ('tf': conda)",
   "language": "python",
   "name": "python37464bittfcondac1d40346f2a44515bd7c51c557ecf43c"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4-final"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}